from mxnet import nd, autograd, gluon
import mxnet as mx
from chapter5 import dataloader  # 自己刚建的读取数据的py文件
from data_set import filepaths  # 自己记录文件地址的py文件
from utils import mxnetUtils  # 自己的util库


class RESCAL(gluon.nn.Block):
    def __init__(self, n_entity, n_relation, dim=200, margin=1):
        super().__init__()
        self.margin = margin  # (式5-14)中的 m
        self.n_entity = n_entity  # 实体的数量
        self.n_relation = n_relation  # 关系的数量
        self.dim = dim  # embedding的长度
        # 随机初始化实体的embedding
        self.e = gluon.nn.Embedding(self.n_entity, dim)
        # 随机初始化关系矩阵的embedding
        self.r = gluon.nn.Embedding(self.n_relation, dim * dim)

    def batch_norm(self):
        for param in self.params:
            param = mxnetUtils.normlize(param)

    def net(self, X):
        x_correct, x_corrupt = X
        y_correct = self.predict(x_correct)
        y_corrupt = self.predict(x_corrupt)
        return self.__hinge_loss(y_correct, y_corrupt)

    def predict(self, x):
        h = self.e(x[:, 0])
        r = self.r(x[:, 1])
        t = self.e(x[:, 2])
        t = t.reshape(-1, self.dim, 1)
        r = r.reshape(-1, self.dim, self.dim)
        tr = nd.batch_dot(r, t)
        tr = tr.reshape(-1, self.dim)
        score = nd.sum(h * tr, -1)
        return -score

    def __hinge_loss(self, dist_correct, dist_corrupt):
        a = dist_correct - dist_corrupt + self.margin
        return nd.maximum(a, 0)


def train(net, dataLoad, pairs, epochs=20, lr=0.01, batchSize=1024):
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    for e in range(epochs):
        l = 0
        for X in dataLoad.iter(pairs, batchSize):
            with autograd.record():
                loss = net.net(X)
            loss.backward()
            trainer.step(batchSize)
            net.batch_norm()
            l += sum(loss).asscalar()
        print("Epoch {}, average loss:{}".format(e, l / len(pairs)))


if __name__ == '__main__':
    entity, relationShips, pairs = dataloader.readData(filepaths.FB15K_237.TEST)
    net = RESCAL(len(entity), len(relationShips))
    net.collect_params().initialize(mx.init.Xavier())

    dataLoad = dataloader.DataIter(entity, relationShips)
    train(net, dataLoad, pairs)
